{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "license"
      },
      "source": [
        "##### *Copyright 2020 Google LLC*\n",
        "*Licensed under the Apache License, Version 2.0 (the \"License\")*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "rKwqeqWBXANA"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hRTa3Ee15WsJ"
      },
      "source": [
        "# Retrain a classification model for Edge TPU using post-training quantization (with TF2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TaX0smDP7xQY"
      },
      "source": [
        "In this tutorial, we'll use TensorFlow 2 to create an image classification model, train it with a flowers dataset, and convert it to TensorFlow Lite using post-training quantization. Finally, we'll compile it for compatibility with the Edge TPU (available in [Coral devices](https://coral.ai/products/)).\n",
        "\n",
        "The model is based on a pre-trained version of MobileNet V2. We'll start by retraining only the classification layers, reusing MobileNet's pre-trained feature extractor layers. Then we'll fine-tune the model by updating weights in some of the feature extractor layers. This type of transfer learning is much faster than training the entire model from scratch.\n",
        "\n",
        "Once it's trained, we'll use post-training quantization to convert all parameters to int8 format, which reduces the model size and increases inference speed. This format is also required for compatibility with the Edge TPU.\n",
        "\n",
        "For more information about how to create a model compatible with the Edge TPU, see the [documentation at coral.ai](https://coral.ai/docs/edgetpu/models-intro/).\n",
        "\n",
        "**Note:** This tutorial requires TensorFlow 2.3+ for full quantization, which currently does not work for all types of models. In particular, this tutorial expects a Keras-built model and this conversion strategy currently doesn't work with models imported from a frozen graph. (If you're using TF 1.x, see [the 1.x version of this tutorial](https://colab.research.google.com/github/google-coral/tutorials/blob/master/retrain_classification_ptq_tf1.ipynb).)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "viewin-badges"
      },
      "source": [
        "\u003ca href=\"https://colab.research.google.com/github/google-coral/tutorials/blob/master/retrain_classification_ptq_tf2.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open in Colab\"\u003e\u003c/a\u003e\n",
        "\u0026nbsp;\u0026nbsp;\u0026nbsp;\u0026nbsp;\n",
        "\u003ca href=\"https://github.com/google-coral/tutorials/blob/master/retrain_classification_ptq_tf2.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://img.shields.io/static/v1?logo=GitHub\u0026label=\u0026color=333333\u0026style=flat\u0026message=View%20on%20GitHub\" alt=\"View in GitHub\"\u003e\u003c/a\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BnSreNhbCQ69"
      },
      "source": [
        "To start running all the code in this tutorial, select **Runtime \u003e Run all** in the Colab toolbar."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GTCYQg_be8C0"
      },
      "source": [
        "## Import the required libraries"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "02MxhCyFmpzn"
      },
      "source": [
        "In order to quantize both the input and output tensors, we need `TFLiteConverter` APIs that are available in TensorFlow r2.3 or higher:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "iBMcobPHdD8O"
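      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "# Compare (major, minor) as integers so that versions such as 2.10 pass the check\n",
        "assert tuple(int(v) for v in tf.__version__.split('.')[:2]) \u003e= (2, 3)\n",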
        "\n",
        "import os\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "v77rlkCKW0IJ"
      },
      "source": [
        "## Prepare the training data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "j4QOy2uA3P_p"
      },
      "source": [
        "First let's download and organize the flowers dataset we'll use to retrain the model (it contains 5 flower classes).\n",
        "\n",
        "Pay attention to this part so you can reproduce it with your own image dataset. In particular, notice that the \"flower_photos\" directory contains an appropriately named directory for each class. The following cells download the photos, divide them into training and validation sets, and generate a labels file based on the photo folder names."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xxL2mjVVGIrV"
      },
      "outputs": [],
      "source": [
        "_URL = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n",
        "\n",
        "zip_file = tf.keras.utils.get_file(origin=_URL, \n",
        "                                   fname=\"flower_photos.tgz\", \n",
        "                                   extract=True)\n",
        "\n",
        "flowers_dir = os.path.join(os.path.dirname(zip_file), 'flower_photos')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "z4gTv7ig2vMh"
      },
      "source": [
        "Next, we use [`ImageDataGenerator`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator) to rescale the image data into float values (divide by 255 so the tensor values are between 0 and 1), and call `flow_from_directory()` to create two generators: one for the training dataset and one for the validation dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "aCLb_yV5JfF3"
      },
      "outputs": [],
      "source": [
        "IMAGE_SIZE = 224\n",
        "BATCH_SIZE = 64\n",
        "\n",
        "datagen = tf.keras.preprocessing.image.ImageDataGenerator(\n",
        "    rescale=1./255, \n",
        "    validation_split=0.2)\n",
        "\n",
        "train_generator = datagen.flow_from_directory(\n",
        "    flowers_dir,\n",
        "    target_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
        "    batch_size=BATCH_SIZE, \n",
        "    subset='training')\n",
        "\n",
        "val_generator = datagen.flow_from_directory(\n",
        "    flowers_dir,\n",
        "    target_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
        "    batch_size=BATCH_SIZE, \n",
        "    subset='validation')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VePDZC5Bh2mO"
      },
      "source": [
        "\n",
        "On each iteration, these generators provide a batch of images by reading images from disk and processing them to the proper tensor size (224 x 224). The output is a tuple of (images, labels). For example, you can see the shapes here:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tx1L7fxxWA_G"
      },
      "outputs": [],
      "source": [
        "image_batch, label_batch = next(val_generator)\n",
        "image_batch.shape, label_batch.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZrFFcwUb3iK9"
      },
      "source": [
        "Now save the class labels to a text file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-QFZIhWs4dsq"
      },
      "outputs": [],
      "source": [
        "print(train_generator.class_indices)\n",
        "\n",
        "labels = '\\n'.join(sorted(train_generator.class_indices.keys()))\n",
        "\n",
        "with open('flower_labels.txt', 'w') as f:\n",
        "  f.write(labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "duxD_UDSOmng"
      },
      "outputs": [],
      "source": [
        "!cat flower_labels.txt"
      ]
    },
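    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough sketch of how a labels file like this might be used later (not part of the original workflow), the following reads a labels file back and maps a class index to its name. The `load_labels` helper and the sample file are hypothetical, and the sketch writes its own small file so that it runs independently of the cells above:\n",
        "\n",
        "```python\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "def load_labels(path):\n",
        "  # Hypothetical helper: one label per line, where the line position is the class index\n",
        "  with open(path, 'r') as f:\n",
        "    return [line.strip() for line in f if line.strip()]\n",
        "\n",
        "# Write a small sample file in the same one-label-per-line format\n",
        "sample_path = os.path.join(tempfile.mkdtemp(), 'sample_labels.txt')\n",
        "with open(sample_path, 'w') as f:\n",
        "  f.write('\\n'.join(['daisy', 'dandelion', 'roses']))\n",
        "\n",
        "sample_labels = load_labels(sample_path)\n",
        "print(sample_labels[1])  # name of the class with index 1\n",
        "```\n",
        "\n",
        "Because `flow_from_directory()` assigns class indices in alphanumeric order of the folder names, the sorted label list lines up with the model's output indices."
      ]
    },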
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OkH-kazQecHB"
      },
      "source": [
        "## Build the model\n",
        "\n",
        "Now we'll create a model for transfer learning in which only the new classification layers on top are trained.\n",
        "\n",
        "We'll start with MobileNet V2 from Keras as the base model, which is pre-trained with the ImageNet dataset (trained to recognize 1,000 classes). This provides us a great feature extractor for image classification and we can then train a new classification layer with our flowers dataset.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AtYKxmW4kS-D"
      },
      "source": [
        "### Create the base model\n",
        "\n",
        "When instantiating the `MobileNetV2` model, we specify the `include_top=False` argument in order to load the network *without* the classification layers at the top. Then we set `trainable` to `False` to freeze all the weights in the base model. This effectively converts the model into a feature extractor because all the pre-trained weights and biases are preserved in the lower layers when we begin training our classification head."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "19IQ2gqneqmS"
      },
      "outputs": [],
      "source": [
        "IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)\n",
        "\n",
        "# Create the base model from the pre-trained MobileNet V2\n",
        "base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,\n",
        "                                              include_top=False, \n",
        "                                              weights='imagenet')\n",
        "base_model.trainable = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wdMRM8YModbk"
      },
      "source": [
        "### Add a classification head\n",
        "\n",
        "Now we create a new [`Sequential`](https://www.tensorflow.org/api_docs/python/tf/keras/Sequential) model and pass the frozen MobileNet model as the base of the graph, and append new classification layers so we can set the final output dimension to match the number of classes in our dataset (5 types of flowers)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eApvroIyn1K0"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "  base_model,\n",
        "  tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu'),\n",
        "  tf.keras.layers.Dropout(0.2),\n",
        "  tf.keras.layers.GlobalAveragePooling2D(),\n",
        "  tf.keras.layers.Dense(units=5, activation='softmax')\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "g0ylJXE_kRLi"
      },
      "source": [
        "### Configure the model\n",
        "\n",
        "Although this method is called `compile()`, it's basically a configuration step that's required before we can start training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RpR8HdyMhukJ"
      },
      "outputs": [],
      "source": [
        "model.compile(optimizer='adam', \n",
        "              loss='categorical_crossentropy', \n",
        "              metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YI-FWMqYlm1X"
      },
      "source": [
        "You can see a string summary of the final network with the `summary()` method:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "I8ARiyMFsgbH"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OjgKQrUem04S"
      },
      "source": [
        "And because the majority of the model graph is frozen in the base model, weights from only the last convolution and dense layers are trainable:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "krvBumovycVA"
      },
      "outputs": [],
      "source": [
        "print('Number of trainable weights = {}'.format(len(model.trainable_weights)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RxvgOYTDSWTx"
      },
      "source": [
        "## Train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kBRNaOCCoA-P"
      },
      "source": [
        "Now we can train the model using data provided by the `train_generator` and `val_generator` that we created at the beginning. \n",
        "\n",
        "This should take less than 10 minutes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "JsaRFlZ9B6WK"
      },
      "outputs": [],
      "source": [
        "history = model.fit(train_generator,\n",
        "                    steps_per_epoch=len(train_generator), \n",
        "                    epochs=10,\n",
        "                    validation_data=val_generator,\n",
        "                    validation_steps=len(val_generator))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Hd94CKImf8vi"
      },
      "source": [
        "### Review the learning curves\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "53OTCh3jnbwV"
      },
      "outputs": [],
      "source": [
        "acc = history.history['accuracy']\n",
        "val_acc = history.history['val_accuracy']\n",
        "\n",
        "loss = history.history['loss']\n",
        "val_loss = history.history['val_loss']\n",
        "\n",
        "plt.figure(figsize=(8, 8))\n",
        "plt.subplot(2, 1, 1)\n",
        "plt.plot(acc, label='Training Accuracy')\n",
        "plt.plot(val_acc, label='Validation Accuracy')\n",
        "plt.legend(loc='lower right')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.ylim([min(plt.ylim()),1])\n",
        "plt.title('Training and Validation Accuracy')\n",
        "\n",
        "plt.subplot(2, 1, 2)\n",
        "plt.plot(loss, label='Training Loss')\n",
        "plt.plot(val_loss, label='Validation Loss')\n",
        "plt.legend(loc='upper right')\n",
        "plt.ylabel('Cross Entropy')\n",
        "plt.ylim([0,1.0])\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.xlabel('epoch')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CqwV-CRdS6Nv"
      },
      "source": [
        "## Fine-tune the base model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dBTEEnxv9X6J"
      },
      "source": [
        "So far, we've only trained the classification layers—the weights of the pre-trained network were *not* changed.\n",
        "\n",
        "One way we can increase the accuracy is to train (or \"fine-tune\") more layers from the pre-trained model. That is, we'll un-freeze some layers from the base model and adjust those weights (which were originally trained with 1,000 ImageNet classes) so they're better tuned for features found in our flowers dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CPXnzUK0QonF"
      },
      "source": [
        "### Un-freeze more layers\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rfxv_ifotQak"
      },
      "source": [
        "So instead of freezing the entire base model, we'll freeze individual layers.\n",
        "\n",
        "First, let's see how many layers are in the base model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4nzcagVitLQm"
      },
      "outputs": [],
      "source": [
        "print(\"Number of layers in the base model: \", len(base_model.layers))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dGcXdaQqASlC"
      },
      "source": [
        "Let's try freezing just the bottom 100 layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-4HgVAacRs5v"
      },
      "outputs": [],
      "source": [
        "base_model.trainable = True\n",
        "fine_tune_at = 100\n",
        "\n",
        "# Freeze all the layers before the `fine_tune_at` layer\n",
        "for layer in base_model.layers[:fine_tune_at]:\n",
        "  layer.trainable = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4Uk1dgsxT0IS"
      },
      "source": [
        "### Reconfigure the model\n",
        "\n",
        "Now configure the model again, but this time with a lower learning rate (the default is 0.001)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NtUnaz0WUDva"
      },
      "outputs": [],
      "source": [
        "model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),\n",
        "              loss='categorical_crossentropy',\n",
        "              metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WwBWy7J2kZvA"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "bNXelbMQtonr"
      },
      "outputs": [],
      "source": [
        "print('Number of trainable weights = {}'.format(len(model.trainable_weights)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4G5O4jd6TuAG"
      },
      "source": [
        "### Continue training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bppmJTmDpXtK"
      },
      "source": [
        "Now let's fine-tune all trainable layers. This starts with the weights we already trained in the classification layers, so we don't need as many epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PiXbLb1O8IDy"
      },
      "outputs": [],
      "source": [
        "history_fine = model.fit(train_generator,\n",
        "                         steps_per_epoch=len(train_generator), \n",
        "                         epochs=5,\n",
        "                         validation_data=val_generator,\n",
        "                         validation_steps=len(val_generator))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xqIjZvhBBJNn"
      },
      "source": [
        "### Review the new learning curves"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "chW103JUItdk"
      },
      "outputs": [],
      "source": [
        "acc = history_fine.history['accuracy']\n",
        "val_acc = history_fine.history['val_accuracy']\n",
        "\n",
        "loss = history_fine.history['loss']\n",
        "val_loss = history_fine.history['val_loss']\n",
        "\n",
        "plt.figure(figsize=(8, 8))\n",
        "plt.subplot(2, 1, 1)\n",
        "plt.plot(acc, label='Training Accuracy')\n",
        "plt.plot(val_acc, label='Validation Accuracy')\n",
        "plt.legend(loc='lower right')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.ylim([min(plt.ylim()),1])\n",
        "plt.title('Training and Validation Accuracy')\n",
        "\n",
        "plt.subplot(2, 1, 2)\n",
        "plt.plot(loss, label='Training Loss')\n",
        "plt.plot(val_loss, label='Validation Loss')\n",
        "plt.legend(loc='upper right')\n",
        "plt.ylabel('Cross Entropy')\n",
        "plt.ylim([0,1.0])\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.xlabel('epoch')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TjdOw_xkqzI4"
      },
      "source": [
        "This is better, but it's not ideal.\n",
        "\n",
        "The validation loss is still higher than the training loss, so there could be some overfitting during training. The overfitting might also be because the new training set is relatively small with less intra-class variance, compared to the original ImageNet dataset used to train MobileNet V2.\n",
        "\n",
        "So this model isn't trained to an accuracy that's production ready, but it works well enough as a demonstration.\n",
        "\n",
        "Let's move on and convert the model to TensorFlow Lite."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kRDabW_u1wnv"
      },
      "source": [
        "## Convert to TFLite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hNvMl6CM6lG4"
      },
      "source": [
        "Ordinarily, creating a TensorFlow Lite model is just a few lines of code with [`TFLiteConverter`](https://www.tensorflow.org/api_docs/python/tf/lite/TFLiteConverter). For example, this creates a basic (un-quantized) TensorFlow Lite model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "srOYhMYfx9XH"
      },
      "outputs": [],
      "source": [
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "with open('mobilenet_v2_1.0_224.tflite', 'wb') as f:\n",
        "  f.write(tflite_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "P0_StAtZwJ5p"
      },
      "source": [
        "However, this `.tflite` file still uses floating-point values for the parameter data, and we need to fully quantize the model to int8 format.\n",
        "\n",
        "To fully quantize the model, we need to perform [post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) with a representative dataset, which requires a few more arguments for the `TFLiteConverter`, and a function that builds a dataset that's representative of the training dataset. \n",
        "\n",
        "So let's convert the model again with post-training quantization:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "w9ydAmHGHUZl"
      },
      "outputs": [],
      "source": [
        "# A generator that provides a representative dataset\n",
        "def representative_data_gen():\n",
        "  # list_files() shuffles by default, so take(100) yields 100 distinct random files\n",
        "  dataset_list = tf.data.Dataset.list_files(flowers_dir + '/*/*')\n",
        "  for image_path in dataset_list.take(100):\n",
        "    image = tf.io.read_file(image_path)\n",
        "    image = tf.io.decode_jpeg(image, channels=3)\n",
        "    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])\n",
        "    image = tf.cast(image / 255., tf.float32)\n",
        "    image = tf.expand_dims(image, 0)\n",
        "    yield [image]\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "# This enables quantization\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "# This sets the representative dataset for quantization\n",
        "converter.representative_dataset = representative_data_gen\n",
        "# This ensures that if any ops can't be quantized, the converter throws an error\n",
        "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
        "# For full integer quantization, though supported types defaults to int8 only, we explicitly declare it for clarity.\n",
        "converter.target_spec.supported_types = [tf.int8]\n",
        "# These set the input and output tensors to uint8 (added in r2.3)\n",
        "converter.inference_input_type = tf.uint8\n",
        "converter.inference_output_type = tf.uint8\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "with open('mobilenet_v2_1.0_224_quant.tflite', 'wb') as f:\n",
        "  f.write(tflite_model)"
      ]
    },
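    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As an optional sanity check (a sketch, assuming the previous cell wrote `mobilenet_v2_1.0_224_quant.tflite` to the working directory), you can load the file with `tf.lite.Interpreter` and print the input and output tensor details. If the converter settings took effect, both tensors should report `uint8`:\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# Assumes the quantized model file from the previous cell exists\n",
        "check_interpreter = tf.lite.Interpreter(model_path='mobilenet_v2_1.0_224_quant.tflite')\n",
        "check_interpreter.allocate_tensors()\n",
        "\n",
        "input_details = check_interpreter.get_input_details()[0]\n",
        "output_details = check_interpreter.get_output_details()[0]\n",
        "\n",
        "# Each entry includes the tensor dtype, shape, and (scale, zero_point) pair\n",
        "print('input :', input_details['dtype'], input_details['shape'], input_details['quantization'])\n",
        "print('output:', output_details['dtype'], output_details['shape'], output_details['quantization'])\n",
        "```"
      ]
    },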
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RMLYBDe_e849"
      },
      "source": [
        "### Compare the accuracy\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SFgbRx_Twd-P"
      },
      "source": [
        "So now we have a fully quantized TensorFlow Lite model. To be sure the conversion went well, let's evaluate both the raw model and the TensorFlow Lite model.\n",
        "\n",
        "First check the accuracy of the raw model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RkQ2IlAWfC5O"
      },
      "outputs": [],
      "source": [
        "batch_images, batch_labels = next(val_generator)\n",
        "\n",
        "# The model ends with a softmax layer, so its outputs are class probabilities\n",
        "probabilities = model(batch_images)\n",
        "prediction = np.argmax(probabilities, axis=1)\n",
        "truth = np.argmax(batch_labels, axis=1)\n",
        "\n",
        "keras_accuracy = tf.keras.metrics.Accuracy()\n",
        "# Metric arguments are ordered (y_true, y_pred)\n",
        "keras_accuracy(truth, prediction)\n",
        "\n",
        "print(\"Raw model accuracy: {:.3%}\".format(keras_accuracy.result()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Hjx3dgZNwmKN"
      },
      "source": [
        "Now let's check the accuracy of the `.tflite` file, using the same dataset.\n",
        "\n",
        "However, there's no convenient API to evaluate the accuracy of a TensorFlow Lite model, so this code runs several inferences and compares the predictions against ground truth:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "iBs0O7q_wlCN"
      },
      "outputs": [],
      "source": [
        "def set_input_tensor(interpreter, input):\n",
        "  input_details = interpreter.get_input_details()[0]\n",
        "  tensor_index = input_details['index']\n",
        "  input_tensor = interpreter.tensor(tensor_index)()[0]\n",
        "  # Inputs for the TFLite model must be uint8, so we quantize our input data.\n",
        "  # NOTE: This step is necessary only because we're receiving input data from\n",
        "  # ImageDataGenerator, which rescaled all image data to float [0,1]. When using\n",
        "  # bitmap inputs, they're already uint8 [0,255] so this can be replaced with:\n",
        "  #   input_tensor[:, :] = input\n",
        "  scale, zero_point = input_details['quantization']\n",
        "  # Round and clip so each value maps to the nearest valid uint8 level\n",
        "  input_tensor[:, :] = np.uint8(np.clip(np.round(input / scale + zero_point), 0, 255))\n",
        "\n",
        "def classify_image(interpreter, input):\n",
        "  set_input_tensor(interpreter, input)\n",
        "  interpreter.invoke()\n",
        "  output_details = interpreter.get_output_details()[0]\n",
        "  output = interpreter.get_tensor(output_details['index'])\n",
        "  # Outputs from the TFLite model are uint8, so we dequantize the results.\n",
        "  # Cast to float first so that subtracting the zero point can't wrap around in uint8.\n",
        "  scale, zero_point = output_details['quantization']\n",
        "  output = scale * (output.astype(np.float32) - zero_point)\n",
        "  top_1 = np.argmax(output)\n",
        "  return top_1\n",
        "\n",
        "interpreter = tf.lite.Interpreter('mobilenet_v2_1.0_224_quant.tflite')\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "# Collect all inference predictions in a list\n",
        "batch_prediction = []\n",
        "batch_truth = np.argmax(batch_labels, axis=1)\n",
        "\n",
        "for i in range(len(batch_images)):\n",
        "  prediction = classify_image(interpreter, batch_images[i])\n",
        "  batch_prediction.append(prediction)\n",
        "\n",
        "# Compare all predictions to the ground truth\n",
        "tflite_accuracy = tf.keras.metrics.Accuracy()\n",
        "tflite_accuracy(batch_prediction, batch_truth)\n",
        "print(\"Quant TF Lite accuracy: {:.3%}\".format(tflite_accuracy.result()))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WfM4kAPiPg9q"
      },
      "source": [
        "You might see some, but hopefully not very much accuracy drop between the raw model and the TensorFlow Lite model. But again, these results are not suitable for production deployment."
      ]
    },
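    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The quantize and dequantize steps used above follow the usual affine mapping between float values and `uint8` values. Here's a minimal sketch of that arithmetic on its own, assuming made-up values for `scale` and `zero_point` (a real model reports its own in the `quantization` field of its input and output details). The helper names `quantize` and `dequantize` are just for illustration:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical quantization parameters, chosen only for this example.\n",
        "# A real model reports its own values in input_details['quantization'].\n",
        "scale, zero_point = 1.0 / 255.0, 0\n",
        "\n",
        "def quantize(x, scale, zero_point):\n",
        "  # Float -> uint8: q = round(x / scale + zero_point), clipped to [0, 255].\n",
        "  q = np.round(x / scale + zero_point)\n",
        "  return np.clip(q, 0, 255).astype(np.uint8)\n",
        "\n",
        "def dequantize(q, scale, zero_point):\n",
        "  # uint8 -> float: x = scale * (q - zero_point).\n",
        "  # Cast to float first so the subtraction cannot wrap around in uint8.\n",
        "  return scale * (q.astype(np.float32) - zero_point)\n",
        "\n",
        "x = np.array([0.0, 0.5, 1.0], dtype=np.float32)\n",
        "q = quantize(x, scale, zero_point)\n",
        "x_restored = dequantize(q, scale, zero_point)\n",
        "print(q, x_restored)\n",
        "```\n",
        "\n",
        "The restored values differ from the originals by at most about half of `scale`, which is the rounding error you accept in exchange for 8-bit tensors."
      ]
    },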
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZmiHICezwXZq"
      },
      "source": [
        "## Compile for the Edge TPU\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DhOzAdzF3Dyk"
      },
      "source": [
        "Finally, we're ready to compile the model for the Edge TPU.\n",
        "\n",
        "First download the [Edge TPU Compiler](https://coral.ai/docs/edgetpu/compiler/):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "p6ZpWgrk21Ad"
      },
      "outputs": [],
      "source": [
        "! curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -\n",
        "\n",
        "! echo \"deb https://packages.cloud.google.com/apt coral-edgetpu-stable main\" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list\n",
        "\n",
        "! sudo apt-get update\n",
        "\n",
        "! sudo apt-get install edgetpu-compiler\t"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mtPcYiER3Ymp"
      },
      "source": [
        "Then compile the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "joxrIB0I3cdi"
      },
      "outputs": [],
      "source": [
        "! edgetpu_compiler mobilenet_v2_1.0_224_quant.tflite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7R8JMQc1MMm5"
      },
      "source": [
        "That's it.\n",
        "\n",
        "The compiled model uses the same filename but with \"_edgetpu\" appended at the end."
      ]
    },
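    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "If you want to compute the compiled filename in a script instead of typing it, something like the following should work. This is only a sketch of the naming pattern described above, and `edgetpu_filename` is a hypothetical helper, not part of the compiler or any Coral library:\n",
        "\n",
        "```python\n",
        "from pathlib import Path\n",
        "\n",
        "def edgetpu_filename(model_path):\n",
        "  # Hypothetical helper: insert '_edgetpu' between the base name and the extension.\n",
        "  path = Path(model_path)\n",
        "  return path.with_name(path.stem + '_edgetpu' + path.suffix).name\n",
        "\n",
        "print(edgetpu_filename('mobilenet_v2_1.0_224_quant.tflite'))\n",
        "```\n",
        "\n",
        "For the model in this tutorial, that prints `mobilenet_v2_1.0_224_quant_edgetpu.tflite`."
      ]
    },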
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Oi9-Voc8A7VK"
      },
      "source": [
        "## Download the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XiugMm-jBbWl"
      },
      "source": [
        "You can download the converted model and labels file from Colab like this: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "x47uW_lI1DoV"
      },
      "outputs": [],
      "source": [
        "from google.colab import files\n",
        "\n",
        "files.download('mobilenet_v2_1.0_224_quant_edgetpu.tflite')\n",
        "files.download('flower_labels.txt')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_qOCP3mXXvsm"
      },
      "source": [
        "If you get a \"Failed to fetch\" error here, it's probably because the files weren't done saving. So just wait a moment and try again.\n",
        "\n",
        "Also look out for a browser popup that might need approval to download the files."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_TZTwG7nhm0C"
      },
      "source": [
        "## Run the model on the Edge TPU\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RwywT4ZpQjLf"
      },
      "source": [
        "You can now run the model on your Coral device with acceleration on the Edge TPU.\n",
        "\n",
        "To get started, try using your `.tflite` model with [this code for image classification with the TensorFlow Lite API](https://github.com/google-coral/tflite/tree/master/python/examples/classification). \n",
        "\n",
        "Just follow the instructions on that page to set up your device, copy the `mobilenet_v2_1.0_224_quant_edgetpu.tflite` and `flower_labels.txt` files to your Coral Dev Board or device with a Coral Accelerator, and pass it a flower photo like this:\n",
        "\n",
        "```\n",
        "python3 classify_image.py \\\n",
        "  --model mobilenet_v2_1.0_224_quant_edgetpu.tflite \\\n",
        "  --labels flower_labels.txt \\\n",
        "  --input flower.jpg\n",
        "```\n",
        "\n",
        "Check out more examples for running inference at [coral.ai/examples](https://coral.ai/examples/#code-examples/)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "retrain_classification_ptq_tf2.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
